{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "J8y9ZkLXmAZc"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "l7u9xMl3UeOX"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "8hTYmggsUAFZ"
      },
      "source": [
        "# TensorFlow 数据集\n",
        "\n",
        "TensorFlow Datasets 提供了一系列可以和 TensorFlow 配合使用的数据集。它负责下载和准备数据，以及构建 [`tf.data.Dataset`](https://tensorflow.google.cn/api_docs/python/tf/data/Dataset)。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "OGw9EgE0tC0C"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://tensorflow.google.cn/datasets/overview\"><img src=\"https://tensorflow.google.cn/images/tf_logo_32px.png\" />在 tensorflow.google.cn 上查看</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/zh-cn/datasets/overview.ipynb\"><img src=\"https://tensorflow.google.cn/images/colab_logo_32px.png\" />在 Google Colab 运行</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/zh-cn/datasets/overview.ipynb\"><img src=\"https://tensorflow.google.cn/images/GitHub-Mark-32px.png\" />在 Github 上查看源代码</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/zh-cn/datasets/overview.ipynb\"><img src=\"https://tensorflow.google.cn/images/download_logo_32px.png\" />下载 notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "y8lL6RLLUtCO"
      },
      "source": [
        "Note: 我们的 TensorFlow 社区翻译了这些文档。因为社区翻译是尽力而为， 所以无法保证它们是最准确的，并且反映了最新的\n",
        "[官方英文文档](https://tensorflow.google.cn/?hl=en)。如果您有改进此翻译的建议， 请提交 pull request 到\n",
        "[tensorflow/docs](https://github.com/tensorflow/docs) GitHub 仓库。要志愿地撰写或者审核译文，请加入\n",
        "[docs-zh-cn@tensorflow.org Google Group](https://groups.google.com/a/tensorflow.org/forum/#!forum/docs-zh-cn)。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "tJc3DDLEUAFf"
      },
      "source": [
        "## 安装\n",
        "\n",
        "`pip install tensorflow-datasets`\n",
        "\n",
        "注意使用 `tensorflow-datasets` 的前提是已经安装好 TensorFlow，目前支持的版本是 `tensorflow` (或者 `tensorflow-gpu`) >= `1.15.0`"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "colab": {},
        "colab_type": "code",
        "id": "boeZp0sYbO41"
      },
      "outputs": [],
      "source": [
        "!pip install -q tensorflow tensorflow-datasets matplotlib"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "vcwlX1pgVtbV"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "TTBSvHcSLBzc"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "\n",
        "import tensorflow_datasets as tfds"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "PJRFn4v-UAFl"
      },
      "source": [
        "## Eager execution\n",
        "\n",
        "TensorFlow 数据集同时兼容 TensorFlow [Eager 模式](https://tensorflow.google.cn/guide/eager) 和图模式。在这个 colab 环境里面，我们的代码将通过 Eager 模式执行。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "o9H2EiXzfNgO"
      },
      "outputs": [],
      "source": [
        "tf.compat.v1.enable_eager_execution()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "mXICnLE4UAFq"
      },
      "source": [
        "## 列出可用的数据集\n",
        "\n",
        "每一个数据集（dataset）都实现了抽象基类 [`tfds.core.DatasetBuilder`](https://www.tensorflow.google.cn/datasets/api_docs/python/tfds/core/DatasetBuilder) 来构建。你可以通过 `tfds.list_builders()` 列出所有可用的数据集。\n",
        "\n",
        "你可以通过 [数据集文档页](https://www.tensorflow.google.cn/datasets/catalog/overview) 查看所有支持的数据集及其补充文档。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "FAvbSVzjLCIb"
      },
      "outputs": [],
      "source": [
        "tfds.list_builders()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "JWFm6VtdUAFt"
      },
      "source": [
        "## `tfds.load`: 一行代码获取数据集\n",
        "\n",
        "[`tfds.load`](https://www.tensorflow.google.cn/datasets/api_docs/python/tfds/load) 是构建并加载 `tf.data.Dataset` 最简单的方式。\n",
        "\n",
        "`tf.data.Dataset` 是构建输入流水线的标准 TensorFlow 接口。如果你对这个接口不熟悉，我们强烈建议你阅读 [TensorFlow 官方指南](https://www.tensorflow.google.cn/guide/datasets)。\n",
        "\n",
        "下面，我们先加载 MNIST 训练数据。这个步骤会下载并准备好该数据，除非你显式指定 `download=False`。值得注意的是，一旦该数据准备好了，后续的 `load` 命令便不会重新下载，可以重复使用准备好的数据。\n",
        "你可以通过指定 `data_dir=` (默认是 `~/tensorflow_datasets/`) 来自定义数据保存/加载的路径。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "dCou80mnLLPV"
      },
      "outputs": [],
      "source": [
        "mnist_train = tfds.load(name=\"mnist\", split=\"train\")\n",
        "assert isinstance(mnist_train, tf.data.Dataset)\n",
        "print(mnist_train)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "xdICgkWrUAFw"
      },
      "source": [
        "加载数据集时，将使用规范的默认版本。 但是，建议你指定要使用的数据集的主版本，并在结果中表明使用了哪个版本的数据集。更多详情请参考[数据集版本控制文档](https://github.com/tensorflow/datasets/blob/master/docs/)。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "OCXCz-vhj0kE"
      },
      "outputs": [],
      "source": [
        "mnist = tfds.load(\"mnist:1.*.*\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "qyhTG75PUAFy"
      },
      "source": [
        "## 特征字典\n",
        "\n",
        "所有 `tfds` 数据集都包含将特征名称映射到 Tensor 值的特征字典。 典型的数据集（如 MNIST）将具有2个键：`\"image\"` 和 `\"label\"`。 下面我们看一个例子。\n",
        "\n",
        "注意：在图模式（graph mode）下，请参阅 [tf.data 指南](https://www.tensorflow.google.cn/guide/datasets#creating_an_iterator) 以了解如何在 `tf.data.Dataset` 上进行迭代。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "YHE21nkHLrER"
      },
      "outputs": [],
      "source": [
        "for mnist_example in mnist_train.take(1):  # 只取一个样本\n",
        "    image, label = mnist_example[\"image\"], mnist_example[\"label\"]\n",
        "\n",
        "    plt.imshow(image.numpy()[:, :, 0].astype(np.float32), cmap=plt.get_cmap(\"gray\"))\n",
        "    print(\"Label: %d\" % label.numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Ee5rTx4IUAF1"
      },
      "source": [
        "## `DatasetBuilder`\n",
        "\n",
        "`tfds.load` 实际上是一个基于 `DatasetBuilder` 的简单方便的包装器。我们可以直接使用 MNIST `DatasetBuilder` 实现与上述相同的操作。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "9FDDJXmBhpQ4"
      },
      "outputs": [],
      "source": [
        "mnist_builder = tfds.builder(\"mnist\")\n",
        "mnist_builder.download_and_prepare()\n",
        "mnist_train = mnist_builder.as_dataset(split=\"train\")\n",
        "mnist_train"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "TaKrseRkUAF4"
      },
      "source": [
        "## 输入流水线\n",
        "\n",
        "一旦有了 `tf.data.Dataset` 对象，就可以使用 [`tf.data` 接口](https://www.tensorflow.google.cn/guide/datasets) 定义适合模型训练的输入流水线的其余部分。\n",
        "\n",
        "在这里，我们将重复使用数据集，以便有无限的样本流，然后随机打乱并创建大小为32的批。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "9OQZqGZMlSE8"
      },
      "outputs": [],
      "source": [
        "mnist_train = mnist_train.repeat().shuffle(1024).batch(32)\n",
        "\n",
        "# prefetch 将使输入流水线可以在模型训练时异步获取批处理。\n",
        "mnist_train = mnist_train.prefetch(tf.data.experimental.AUTOTUNE)\n",
        "\n",
        "# 现在你可以遍历数据集的批次并在 mnist_train 中训练批次：\n",
        "#   ..."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "4OT8i4qhUAF7"
      },
      "source": [
        "## 数据集信息（DatasetInfo）\n",
        "\n",
        "生成后，构建器将包含有关数据集的有用信息："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "mSamfFznA9Ph"
      },
      "outputs": [],
      "source": [
        "info = mnist_builder.info\n",
        "print(info)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "MVMLnarBUAF-"
      },
      "source": [
        "`DatasetInfo` 还包含特征相关的有用信息："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "u1wL14QH2TW1"
      },
      "outputs": [],
      "source": [
        "print(info.features)\n",
        "print(info.features[\"label\"].num_classes)\n",
        "print(info.features[\"label\"].names)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "YTiE3Ve9UAGB"
      },
      "source": [
        "你也可以通过 `tfds.load` 使用 `with_info = True` 加载 `DatasetInfo`。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "tvZYujQwBL7B"
      },
      "outputs": [],
      "source": [
        "mnist_test, info = tfds.load(\"mnist\", split=\"test\", with_info=True)\n",
        "print(info)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "skNcP7LnUAGE"
      },
      "source": [
        "## 可视化\n",
        "\n",
        "对于图像分类数据集，你可以使用 `tfds.show_examples` 来可视化一些样本。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "DpE2FD56cSQR"
      },
      "outputs": [],
      "source": [
        "fig = tfds.show_examples(info, mnist_test)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "overview.ipynb",
      "private_outputs": true,
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
